{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "import torch\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('../common/')\n",
    "\n",
    "from evaluator import *\n",
    "from utils import *\n",
    "from usad_gan_only import  *\n",
    "gpu_choice =get_default_device().index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_windows(data, window_size):\n",
    "    windows = []\n",
    "    length  = data.shape[0]\n",
    "    for i in range(length):\n",
    "        if length - i >= window_size:\n",
    "            window = data[i:i+window_size]\n",
    "        else:\n",
    "            window = data[i-window_size+1:i+1]\n",
    "        windows.append(window)    \n",
    "    return np.array(windows)\n",
    "\n",
    "def load_dataset(dataset, machine, window_size):\n",
    "    folder = os.path.join(\"../../processed\", dataset)\n",
    "    if not os.path.exists(folder):\n",
    "        raise Exception(\"Processed Data not found.\")\n",
    "    loader = []\n",
    "    for file in [\"train\", \"test\", \"labels\"]:\n",
    "        file = machine + \"_\" + file\n",
    "        loader.append(np.load(os.path.join(folder, f\"{file}.npy\")))\n",
    "    ## 准备数据\n",
    "    train_data = convert_to_windows(loader[0], window_size)\n",
    "    test_data = convert_to_windows(loader[1], window_size)\n",
    "    labels = np.zeros((loader[2].shape[0], 1))\n",
    "    for i, row in enumerate(loader[2]):\n",
    "        if np.any(row == 1):\n",
    "            labels[i] = 1   \n",
    "        \n",
    "    return (train_data, test_data, labels)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=1e-4\n",
    "num_epochs=100\n",
    "window_size = 5\n",
    "z_dim = 38\n",
    "batch_size = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total datas: 28\n"
     ]
    }
   ],
   "source": [
    "origin_file_list = os.listdir(\"../data/SMD/interpretation_label\")\n",
    "file_name_list = []\n",
    "for origin_file in origin_file_list:\n",
    "    file_name_list.append(origin_file[:-4])\n",
    "file_name_list.sort()\n",
    "\n",
    "datas = {}\n",
    "for file_name in file_name_list:\n",
    "    train_data, test_data, labels = load_dataset(\"SMD\", file_name,window_size)\n",
    "    datas[file_name] = (train_data, test_data, labels)\n",
    "    # print(f\"train_data shape: {train_data.shape}\")\n",
    "    # print(f\"test_data shape: {test_data.shape}\")\n",
    "    # print(f\"labels shape: {labels.shape}\")\n",
    "print(f\"total datas: {len(datas)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:09<00:00,  1.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:59<00:00,  1.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:07<00:00,  1.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:08<00:00,  1.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:57<00:00,  1.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:56<00:00,  1.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:56<00:00,  1.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-1-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:56<00:00,  1.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:55<00:00,  1.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:58<00:00,  1.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:57<00:00,  1.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:57<00:00,  1.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:55<00:00,  1.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:10<00:00,  1.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:55<00:00,  1.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:55<00:00,  1.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-2-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:10<00:00,  1.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:12<00:00,  1.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:59<00:00,  1.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:14<00:00,  1.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:58<00:00,  1.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:57<00:00,  1.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:53<00:00,  1.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:54<00:00,  1.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:08<00:00,  1.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:09<00:00,  1.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:08<00:00,  1.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "machine-3-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:10<00:00,  1.42it/s]\n"
     ]
    }
   ],
   "source": [
    "train_gaussian_percentage = 0.25\n",
    "\n",
    "results = {}\n",
    "for machine in file_name_list:\n",
    "    print(machine) \n",
    "    train_window,test_window,labels = datas[machine]\n",
    "    train_dataset = torch.from_numpy(train_window).float().view([train_window.shape[0], \n",
    "                                     train_window.shape[1]*train_window.shape[2]])\n",
    "    test_dataset = torch.from_numpy(test_window).float().view([test_window.shape[0], \n",
    "                                     test_window.shape[1]*test_window.shape[2]])\n",
    "    model = UsadGANModel(train_window.shape[2]*window_size, z_dim)\n",
    "    model.to(get_default_device())\n",
    "    indices = np.random.permutation(len(train_dataset))\n",
    "    split_point = int(train_gaussian_percentage * len(train_dataset))\n",
    "\n",
    "    \n",
    "    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                              sampler=SubsetRandomSampler(indices[:-split_point]), pin_memory=False)\n",
    "    val_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                                       sampler=SubsetRandomSampler(indices[-split_point:]), pin_memory=False)\n",
    "    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)\n",
    "\n",
    "\n",
    "    history = model.fit(train_loader=train_loader,validation_loader=val_loader,epochs=num_epochs)\n",
    "    pred= model.predict_prob(test_loader)\n",
    "    result = bf_search(labels.flatten(), pred, verbose=False)\n",
    "    results[machine] =result\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision mean: 0.897\n",
      "recall mean: 0.8475\n",
      "f1 mean: 0.8513\n",
      "f1* mean: 0.8715\n"
     ]
    }
   ],
   "source": [
    "f1_scores = [item['f1-score'] for item in list(results.values())]\n",
    "precisions = [item['precision'] for item in list(results.values())]\n",
    "recalls = [item['recall'] for item in list(results.values())]\n",
    "TNs = [item['TN'] for item in list(results.values())]\n",
    "TPs = [item['TP'] for item in list(results.values())]\n",
    "FNs = [item['FN'] for item in list(results.values())]\n",
    "FPs = [item['FP'] for item in list(results.values())]\n",
    "f1_mean = round(np.mean(f1_scores).item(),4)\n",
    "precision_mean = round(np.mean(precisions),4)\n",
    "recall_mean = round(np.mean(recalls),4)\n",
    "f1_star = round(2 * precision_mean * recall_mean / (precision_mean + recall_mean + 0.00001),4)\n",
    "print(f\"precision mean: {precision_mean}\")\n",
    "print(f\"recall mean: {recall_mean}\")\n",
    "print(f\"f1 mean: {f1_mean}\")\n",
    "print(f\"f1* mean: {f1_star}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=1e-4\n",
    "num_epochs = 100\n",
    "window_size = 5\n",
    "z_dim = 55\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "origin_file_list = os.listdir(\"../processed/SMAP\")\n",
    "file_name_set = set()\n",
    "for origin_file in origin_file_list:\n",
    "    file_name_set.add(origin_file.split(\"_\")[0])\n",
    "file_name_list = list(file_name_set)\n",
    "file_name_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "datas = {}\n",
    "for file_name in file_name_list:\n",
    "    train_data, test_data, labels = load_dataset(\"SMAP\", file_name, window_size)\n",
    "    datas[file_name] = (train_data, test_data, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:26<00:00,  1.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:21<00:00,  1.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:33<00:00,  1.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:21<00:00,  4.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:21<00:00,  4.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:23<00:00,  4.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:22<00:00,  4.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "B-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:16<00:00,  1.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:26<00:00,  1.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:19<00:00,  1.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:09<00:00, 10.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:27<00:00,  1.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:26<00:00,  1.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:20<00:00,  1.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:19<00:00,  1.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:20<00:00,  1.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:19<00:00,  1.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:17<00:00,  1.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:36<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:35<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:35<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:33<00:00,  1.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:27<00:00,  1.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:28<00:00,  1.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:24<00:00,  1.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:28<00:00,  1.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:26<00:00,  1.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:36<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:34<00:00,  1.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:36<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:35<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:36<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:34<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:25<00:00,  1.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:33<00:00,  1.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:24<00:00,  1.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:41<00:00,  1.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "G-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:23<00:00,  1.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:36<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:34<00:00,  1.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:37<00:00,  1.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:28<00:00,  1.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:34<00:00,  1.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "R-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:43<00:00,  1.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:41<00:00,  1.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:32<00:00,  1.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:30<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:32<00:00,  1.08it/s]\n"
     ]
    }
   ],
   "source": [
    "train_gaussian_percentage = 0.25\n",
    "\n",
    "results = {}\n",
    "for machine in file_name_list:\n",
    "    print(machine) \n",
    "    train_window,test_window,labels = datas[machine]\n",
    "    train_dataset = torch.from_numpy(train_window).float().view([train_window.shape[0], \n",
    "                                     train_window.shape[1]*train_window.shape[2]])\n",
    "    test_dataset = torch.from_numpy(test_window).float().view([test_window.shape[0], \n",
    "                                     test_window.shape[1]*test_window.shape[2]])\n",
    "       # 每个模型单独训练\n",
    "    model = UsadGANModel(train_window.shape[2]*window_size, z_dim)\n",
    "    model.to(get_default_device())\n",
    "    indices = np.random.permutation(len(train_dataset))\n",
    "    split_point = int(train_gaussian_percentage * len(train_dataset))\n",
    "\n",
    "    \n",
    "    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                              sampler=SubsetRandomSampler(indices[:-split_point]), pin_memory=False)\n",
    "    val_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                                       sampler=SubsetRandomSampler(indices[-split_point:]), pin_memory=False)\n",
    "    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)\n",
    "\n",
    "\n",
    "    history = model.fit(train_loader=train_loader,validation_loader=val_loader,epochs=num_epochs)\n",
    "    pred= model.predict_prob(test_loader)\n",
    "    result = bf_search(labels.flatten(), pred, verbose=False)\n",
    "    results[machine] =result\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision mean: 0.7708\n",
      "recall mean: 0.8861\n",
      "f1 mean: 0.8073\n",
      "f1* mean: 0.8244\n"
     ]
    }
   ],
   "source": [
    "f1_scores = [item['f1-score'] for item in list(results.values())]\n",
    "precisions = [item['precision'] for item in list(results.values())]\n",
    "recalls = [item['recall'] for item in list(results.values())]\n",
    "TNs = [item['TN'] for item in list(results.values())]\n",
    "TPs = [item['TP'] for item in list(results.values())]\n",
    "FNs = [item['FN'] for item in list(results.values())]\n",
    "FPs = [item['FP'] for item in list(results.values())]\n",
    "f1_mean = round(np.mean(f1_scores).item(),4)\n",
    "precision_mean = round(np.mean(precisions),4)\n",
    "recall_mean = round(np.mean(recalls),4)\n",
    "f1_star = round(2 * precision_mean * recall_mean / (precision_mean + recall_mean + 0.00001),4)\n",
    "print(f\"precision mean: {precision_mean}\")\n",
    "print(f\"recall mean: {recall_mean}\")\n",
    "print(f\"f1 mean: {f1_mean}\")\n",
    "print(f\"f1* mean: {f1_star}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MSL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=1e-4\n",
    "num_epochs = 100\n",
    "window_size = 5\n",
    "z_dim = 33\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "origin_file_list = os.listdir(\"../processed/MSL\")\n",
    "file_name_set = set()\n",
    "for origin_file in origin_file_list:\n",
    "    file_name_set.add(origin_file.split(\"_\")[0])\n",
    "file_name_list = list(file_name_set)\n",
    "file_name_list.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "datas = {}\n",
    "for file_name in file_name_list:\n",
    "    train_data, test_data, labels = load_dataset(\"MSL\", file_name, window_size)\n",
    "    datas[file_name] = (train_data, test_data, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:17<00:00,  1.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:25<00:00,  3.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:58<00:00,  1.18s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:05<00:00,  1.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D-16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:44<00:00,  2.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:08<00:00,  1.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:22<00:00,  1.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:19<00:00,  1.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:51<00:00,  1.11s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:12<00:00,  1.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:13<00:00,  1.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:04<00:00,  1.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:05<00:00,  1.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:01<00:00,  1.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:48<00:00,  2.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:52<00:00,  1.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [02:11<00:00,  1.31s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [02:08<00:00,  1.29s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:29<00:00,  1.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P-15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:52<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S-2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:28<00:00,  3.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:34<00:00,  2.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:33<00:00,  2.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:10<00:00,  1.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:10<00:00,  1.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:22<00:00,  4.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T-9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:12<00:00,  7.94it/s]\n"
     ]
    }
   ],
   "source": [
    "train_gaussian_percentage = 0.25\n",
    "\n",
    "results = {}\n",
    "for machine in file_name_list:\n",
    "    print(machine) \n",
    "    train_window,test_window,labels = datas[machine]\n",
    "    train_dataset = torch.from_numpy(train_window).float().view([train_window.shape[0], \n",
    "                                     train_window.shape[1]*train_window.shape[2]])\n",
    "    test_dataset = torch.from_numpy(test_window).float().view([test_window.shape[0], \n",
    "                                     test_window.shape[1]*test_window.shape[2]])\n",
    "       # 每个模型单独训练\n",
    "    model = UsadGANModel(train_window.shape[2]*window_size, z_dim)\n",
    "    model.to(get_default_device())\n",
    "    indices = np.random.permutation(len(train_dataset))\n",
    "    split_point = int(train_gaussian_percentage * len(train_dataset))\n",
    "\n",
    "    \n",
    "    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                              sampler=SubsetRandomSampler(indices[:-split_point]), pin_memory=False)\n",
    "    val_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True,\n",
    "                                       sampler=SubsetRandomSampler(indices[-split_point:]), pin_memory=False)\n",
    "    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)\n",
    "\n",
    "\n",
    "    history = model.fit(train_loader=train_loader,validation_loader=val_loader,epochs=num_epochs)\n",
    "    pred= model.predict_prob(test_loader)\n",
    "    result = bf_search(labels.flatten(), pred, verbose=False)\n",
    "    results[machine] =result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision mean: 0.8399\n",
      "recall mean: 0.9655\n",
      "f1 mean: 0.8781\n",
      "f1* mean: 0.8983\n"
     ]
    }
   ],
   "source": [
    "f1_scores = [item['f1-score'] for item in list(results.values())]\n",
    "precisions = [item['precision'] for item in list(results.values())]\n",
    "recalls = [item['recall'] for item in list(results.values())]\n",
    "TNs = [item['TN'] for item in list(results.values())]\n",
    "TPs = [item['TP'] for item in list(results.values())]\n",
    "FNs = [item['FN'] for item in list(results.values())]\n",
    "FPs = [item['FP'] for item in list(results.values())]\n",
    "f1_mean = round(np.mean(f1_scores).item(),4)\n",
    "precision_mean = round(np.mean(precisions),4)\n",
    "recall_mean = round(np.mean(recalls),4)\n",
    "f1_star = round(2 * precision_mean * recall_mean / (precision_mean + recall_mean + 0.00001),4)\n",
    "print(f\"precision mean: {precision_mean}\")\n",
    "print(f\"recall mean: {recall_mean}\")\n",
    "print(f\"f1 mean: {f1_mean}\")\n",
    "print(f\"f1* mean: {f1_star}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
